import torch
import numpy as np
import torch.nn.functional as F
import os
from transformers import AutoTokenizer, AutoModel
from model.modeling_llada import LLaDAModelLM

def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits so that a following argmax draws a categorical sample (Gumbel-max).

    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves the
    perplexity score but reduces generation quality, so the noise is computed in float64.

    Args:
        logits: Tensor of unnormalised scores.
        temperature: Sampling temperature. 0 disables the noise entirely.

    Returns:
        `logits` itself (untouched, original dtype) when temperature is 0; otherwise a
        float64 tensor equal to exp(logits) / (-log(u)) ** temperature, u ~ U[0, 1).
    '''
    if temperature == 0:
        return logits

    scores = logits.to(torch.float64)
    uniform = torch.rand_like(scores, dtype=torch.float64)
    scaled_noise = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(scores) / scaled_noise

def compute_block_length(
    logits,
    predicted_tokens,
    prompt,
    gen_length,
    generated_length,
    default_block_length,
    newline_token_id=198,
    confidence_threshold=float('inf')
):
    '''
    Choose the length of the next block to decode.

    The first quarter of gen_length (capped by what is left to generate) after the
    already generated part is scanned for predicted newline tokens. If the most
    confident newline reaches confidence_threshold, the block ends on that newline;
    otherwise the default block length (capped by the remaining length) is used.
    Only batch row 0 of logits / predicted_tokens is inspected.

    Args:
        logits: Model logits, indexed as logits[0, position, token].
        predicted_tokens: Predicted token ids, indexed as predicted_tokens[0, position].
        prompt: Prompt tensor; only prompt.shape[1] is used.
        gen_length: Total number of tokens to generate.
        generated_length: Number of tokens already assigned to previous blocks.
        default_block_length: Block length used when no newline qualifies.
        newline_token_id: Token id treated as the block delimiter.
        confidence_threshold: Minimum softmax probability of the newline token.

    Returns:
        The block length as a Python number.
    '''
    block_start = prompt.shape[1] + generated_length
    remaining = gen_length - generated_length
    fallback = min(default_block_length, remaining)

    # Look ahead over a quarter of the generation budget, never past the end.
    window_length = min(int(0.25 * gen_length), remaining)
    window_tokens = predicted_tokens[0, block_start:block_start + window_length]
    is_newline = window_tokens == newline_token_id
    if not torch.any(is_newline):
        return fallback

    # Absolute positions of the newline candidates and their softmax probability.
    candidate_positions = block_start + torch.nonzero(is_newline).squeeze(-1)
    candidate_logits = logits[0, candidate_positions, :]
    log_norm = torch.logsumexp(candidate_logits, dim=-1)
    confidences = torch.exp(candidate_logits[:, newline_token_id] - log_norm)

    best_confidence, best_idx = torch.max(confidences, dim=0)
    if best_confidence.item() >= confidence_threshold:
        # The block ends on (and includes) the most confident newline.
        return candidate_positions[best_idx].item() - block_start + 1
    return fallback

@torch.no_grad()
def generate_semantic(model, prompt, steps=128, gen_length=128, init_block_length=128, temperature=0.,
            remasking='low_confidence', mask_id=126336, threshold=None, factor=None, confidence_threshold=float('inf')):
    '''
    Block-wise masked-diffusion decoding without a KV cache.

    Every outer iteration runs the model on the whole sequence, chooses the length of
    the next block with compute_block_length (which only inspects batch row 0), and
    then keeps denoising until that block contains no mask token.

    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, L).
        steps: Sampling steps, less than or equal to gen_length. Only used when both
            threshold and factor are None, in which case it is split evenly across blocks
            and the masked tokens of a block are spread evenly over the per-block steps.
        gen_length: Generated answer length.
        init_block_length: Initial block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
        threshold: Confidence threshold forwarded to get_transfer_index.
        factor: Factor forwarded to get_transfer_index_dynamic; when given it takes
            precedence over threshold.
        confidence_threshold: Confidence threshold for block length prediction.

    Returns:
        (x, nfe_history, block_history): the token tensor (prompt followed by the
        generated tokens), the number of model calls spent on each block, and the
        length of each block.
    '''
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    use_fixed_schedule = threshold is None and factor is None
    if use_fixed_schedule:
        assert confidence_threshold == float('inf'), "Currently does not support top-k sampling with adaptive block length"
        assert gen_length % init_block_length == 0, "gen_length must be divisible by init_block_length for fixed block length"
        num_blocks = gen_length // init_block_length
        assert steps % num_blocks == 0, "steps must be divisible by num_blocks for top-k sampling"
        steps = steps // num_blocks
        assert steps > 0, "steps must be at least num_blocks for top-k sampling"

    def build_schedule(block_mask):
        # (b, steps) number of tokens to commit at each step; None unless the fixed
        # top-k schedule is active. Previously None was always passed on, which made
        # get_transfer_index fail when neither threshold nor factor was given.
        if not use_fixed_schedule:
            return None
        masked = block_mask.sum(dim=1, keepdim=True)
        step_ids = torch.arange(steps, device=masked.device)
        return masked // steps + (step_ids < masked % steps).to(masked.dtype)

    def transfer(step_logits, step_tokens, step_mask, tokens, schedule, step):
        # factor wins over threshold when both are set (same branching as before).
        if factor is not None:
            return get_transfer_index_dynamic(step_logits, step_tokens, remasking, step_mask, tokens, None, factor)
        if schedule is None:
            quota = None
        elif step < schedule.shape[1]:
            quota = schedule[:, step]
        else:
            # Schedule exhausted but masks remain (should not normally happen):
            # commit everything that is left instead of indexing out of range.
            quota = step_mask.sum(dim=1)
        return get_transfer_index(step_logits, step_tokens, remasking, step_mask, tokens, quota, threshold)

    generated_length = 0
    nfe_history = []
    block_history = []  # Track block lengths like in generate_block.py
    while generated_length < gen_length:
        nfe = 0

        # First denoising step of the block: full forward pass, also used to size the block.
        output = model(x)
        logits = output.logits
        logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
        predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
        nfe += 1

        block_length = compute_block_length(logits, predicted_tokens, prompt, gen_length, generated_length, init_block_length, confidence_threshold=confidence_threshold)
        block_history.append(block_length)

        current_block_start = prompt.shape[1] + generated_length
        current_block_end = current_block_start + block_length
        generated_length += block_length

        # only allow transfer tokens in current block
        mask_index = (x == mask_id)
        mask_index[:, current_block_end:] = 0

        schedule = build_schedule(mask_index)
        step = 0
        x0, transfer_index = transfer(logits, predicted_tokens, mask_index, x, schedule, step)
        x[transfer_index] = x0[transfer_index]

        # Remaining denoising steps: repeat until the block has no mask left.
        while True:
            if (x[:, current_block_start:current_block_end] == mask_id).sum() == 0:
                break
            step += 1
            mask_index = (x == mask_id)
            mask_index[:, current_block_end:] = 0
            block_output = model(x)
            block_logits = block_output.logits
            block_logits_with_noise = add_gumbel_noise(block_logits, temperature=temperature)
            block_predicted_tokens = torch.argmax(block_logits_with_noise, dim=-1)
            nfe += 1
            x0, transfer_index = transfer(block_logits, block_predicted_tokens, mask_index, x, schedule, step)
            x[transfer_index] = x0[transfer_index]
        nfe_history.append(nfe)

    return x, nfe_history, block_history


@torch.no_grad()
def generate_with_prefix_cache(model, prompt, steps=128, gen_length=128, init_block_length=128, temperature=0.,
             remasking='low_confidence', mask_id=126336, threshold=None, factor=None, confidence_threshold=float('inf')):
    '''
    Block-wise masked-diffusion decoding with a prefix KV cache.

    The first denoising step of each block runs the model on the whole sequence and
    keeps the key/value cache of everything before the block. Later steps of the block
    only feed the suffix starting at the block together with that cached prefix.
    The block length comes from compute_block_length (which only inspects batch row 0).

    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, L).
        steps: Sampling steps, less than or equal to gen_length. Only used when both
            threshold and factor are None, in which case it is split evenly across blocks
            and the masked tokens of a block are spread evenly over the per-block steps.
        gen_length: Generated answer length.
        init_block_length: Initial block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
        threshold: Confidence threshold forwarded to get_transfer_index.
        factor: Factor forwarded to get_transfer_index_dynamic; when given it takes
            precedence over threshold.
        confidence_threshold: Confidence threshold for block length prediction.

    Returns:
        (x, nfe_history, block_history): the token tensor (prompt followed by the
        generated tokens), the number of model calls spent on each block, and the
        length of each block.
    '''
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    use_fixed_schedule = threshold is None and factor is None
    if use_fixed_schedule:
        assert confidence_threshold == float('inf'), "Currently does not support top-k sampling with adaptive block length"
        assert gen_length % init_block_length == 0, "gen_length must be divisible by init_block_length for fixed block length"
        num_blocks = gen_length // init_block_length
        assert steps % num_blocks == 0, "steps must be divisible by num_blocks for top-k sampling"
        steps = steps // num_blocks
        assert steps > 0, "steps must be at least num_blocks for top-k sampling"

    def build_schedule(block_mask):
        # (b, steps) number of tokens to commit at each step; None unless the fixed
        # top-k schedule is active. Previously None was always passed on, which made
        # get_transfer_index fail when neither threshold nor factor was given.
        if not use_fixed_schedule:
            return None
        masked = block_mask.sum(dim=1, keepdim=True)
        step_ids = torch.arange(steps, device=masked.device)
        return masked // steps + (step_ids < masked % steps).to(masked.dtype)

    def transfer(step_logits, step_tokens, step_mask, tokens, schedule, step):
        # factor wins over threshold when both are set (same branching as before).
        if factor is not None:
            return get_transfer_index_dynamic(step_logits, step_tokens, remasking, step_mask, tokens, None, factor)
        if schedule is None:
            quota = None
        elif step < schedule.shape[1]:
            quota = schedule[:, step]
        else:
            # Schedule exhausted but masks remain (should not normally happen):
            # commit everything that is left instead of indexing out of range.
            quota = step_mask.sum(dim=1)
        return get_transfer_index(step_logits, step_tokens, remasking, step_mask, tokens, quota, threshold)

    generated_length = 0
    nfe_history = []
    block_history = []

    while generated_length < gen_length:
        nfe = 0

        # update full cache in first denoise step of each block
        output = model(x, use_cache=True)
        past_kv = output.past_key_values
        logits = output.logits
        logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
        predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
        nfe += 1

        block_length = compute_block_length(logits, predicted_tokens, prompt, gen_length, generated_length, init_block_length, confidence_threshold=confidence_threshold)
        block_history.append(block_length)

        current_block_start = prompt.shape[1] + generated_length
        current_block_end = current_block_start + block_length
        generated_length += block_length

        # only allow transfer tokens in current block
        mask_index = (x == mask_id)
        mask_index[:, current_block_end:] = 0

        schedule = build_schedule(mask_index)
        step = 0
        x0, transfer_index = transfer(logits, predicted_tokens, mask_index, x, schedule, step)
        x[transfer_index] = x0[transfer_index]

        # Keep only the cache entries of the positions before the current block.
        past_kv = [
            tuple(cached[:, :, :current_block_start] for cached in layer)
            for layer in past_kv
        ]

        while True:
            if (x[:, current_block_start:current_block_end] == mask_id).sum() == 0:
                break
            step += 1
            # Everything from the block start onwards is re-encoded against the cached prefix.
            suffix = x[:, current_block_start:]
            mask_index = (suffix == mask_id)
            mask_index[:, block_length:] = 0
            block_output = model(suffix, past_key_values=past_kv, use_cache=True)
            block_logits = block_output.logits
            block_logits_with_noise = add_gumbel_noise(block_logits, temperature=temperature)
            block_predicted_tokens = torch.argmax(block_logits_with_noise, dim=-1)
            nfe += 1
            x0, transfer_index = transfer(block_logits, block_predicted_tokens, mask_index, suffix, schedule, step)
            # suffix is a view of x, so this writes the committed tokens into x.
            suffix[transfer_index] = x0[transfer_index]
        nfe_history.append(nfe)

    return x, nfe_history, block_history

@torch.no_grad()
def generate_semantic_dual_cache(model, prompt, steps=128, gen_length=128, init_block_length=128, temperature=0.,
            remasking='low_confidence', mask_id=126336, threshold=None, factor=None, confidence_threshold=float('inf')):
    '''
    Block-wise masked-diffusion decoding with a cache covering the whole sequence.

    The first denoising step of each block runs the model on the whole sequence and
    stores its key/value cache. Later steps of the block feed only the block tokens,
    passing the cache and a boolean replace_position mask that marks the block
    positions. The block length comes from compute_block_length (which only inspects
    batch row 0).

    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, L).
        steps: Sampling steps, less than or equal to gen_length. Only used when both
            threshold and factor are None, in which case it is split evenly across blocks
            and the masked tokens of a block are spread evenly over the per-block steps.
        gen_length: Generated answer length.
        init_block_length: Initial block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
        threshold: Confidence threshold forwarded to get_transfer_index.
        factor: Factor forwarded to get_transfer_index_dynamic; when given it takes
            precedence over threshold.
        confidence_threshold: Confidence threshold for block length prediction.

    Returns:
        (x, nfe_history, block_history): the token tensor (prompt followed by the
        generated tokens), the number of model calls spent on each block, and the
        length of each block.
    '''
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    use_fixed_schedule = threshold is None and factor is None
    if use_fixed_schedule:
        assert confidence_threshold == float('inf'), "Currently does not support top-k sampling with adaptive block length"
        assert gen_length % init_block_length == 0, "gen_length must be divisible by init_block_length for fixed block length"
        num_blocks = gen_length // init_block_length
        assert steps % num_blocks == 0, "steps must be divisible by num_blocks for top-k sampling"
        steps = steps // num_blocks
        assert steps > 0, "steps must be at least num_blocks for top-k sampling"

    def build_schedule(block_mask):
        # (b, steps) number of tokens to commit at each step; None unless the fixed
        # top-k schedule is active. Previously None was always passed on, which made
        # get_transfer_index fail when neither threshold nor factor was given.
        if not use_fixed_schedule:
            return None
        masked = block_mask.sum(dim=1, keepdim=True)
        step_ids = torch.arange(steps, device=masked.device)
        return masked // steps + (step_ids < masked % steps).to(masked.dtype)

    def transfer(step_logits, step_tokens, step_mask, tokens, schedule, step):
        # factor wins over threshold when both are set (same branching as before).
        if factor is not None:
            return get_transfer_index_dynamic(step_logits, step_tokens, remasking, step_mask, tokens, None, factor)
        if schedule is None:
            quota = None
        elif step < schedule.shape[1]:
            quota = schedule[:, step]
        else:
            # Schedule exhausted but masks remain (should not normally happen):
            # commit everything that is left instead of indexing out of range.
            quota = step_mask.sum(dim=1)
        return get_transfer_index(step_logits, step_tokens, remasking, step_mask, tokens, quota, threshold)

    generated_length = 0
    nfe_history = []
    block_history = []

    while generated_length < gen_length:
        nfe = 0

        # update full cache in first denoising step of each block
        output = model(x, use_cache=True)
        past_kv = output.past_key_values
        logits = output.logits
        logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
        predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
        nfe += 1

        block_length = compute_block_length(logits, predicted_tokens, prompt, gen_length, generated_length, init_block_length, confidence_threshold=confidence_threshold)
        block_history.append(block_length)

        current_block_start = prompt.shape[1] + generated_length
        current_block_end = current_block_start + block_length
        generated_length += block_length

        # only allow transfer tokens in current block
        mask_index = (x == mask_id)
        mask_index[:, current_block_end:] = 0

        schedule = build_schedule(mask_index)
        step = 0
        x0, transfer_index = transfer(logits, predicted_tokens, mask_index, x, schedule, step)
        x[transfer_index] = x0[transfer_index]

        # Marks the positions of the current block for the model's cached forward pass.
        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, current_block_start:current_block_end] = 1

        # 2nd and later denoising steps in block
        while True:
            block_tokens = x[:, current_block_start:current_block_end]
            if (block_tokens == mask_id).sum() == 0:
                break
            step += 1
            # mask_index has size [b, block_length] here
            mask_index = (block_tokens == mask_id)
            block_output = model(block_tokens, past_key_values=past_kv, use_cache=True, replace_position=replace_position)
            block_logits = block_output.logits
            block_logits_with_noise = add_gumbel_noise(block_logits, temperature=temperature)
            block_predicted_tokens = torch.argmax(block_logits_with_noise, dim=-1)
            nfe += 1
            x0, transfer_index = transfer(block_logits, block_predicted_tokens, mask_index, block_tokens, schedule, step)
            # block_tokens is a view of x, so this writes the committed tokens into x.
            block_tokens[transfer_index] = x0[transfer_index]
        nfe_history.append(nfe)

    return x, nfe_history, block_history

@torch.no_grad()
def generate_semantic_lazy_dual_cache(model, prompt, steps=128, gen_length=128, init_block_length=128, temperature=0.,
            remasking='low_confidence', mask_id=126336, threshold=None, factor=None, confidence_threshold=float('inf'), cache_update_interval=128):
    '''
    Block-wise masked-diffusion decoding with a lazily rebuilt full-sequence cache.

    A full forward pass over the whole sequence (which rebuilds the key/value cache and
    predicts the block length through compute_block_length) is only done when the
    number of tokens generated since the last full pass has reached
    cache_update_interval; the very first block always triggers it. Otherwise the next
    block has init_block_length tokens (capped by the remaining length) and is decoded
    against the existing cache. When a block is finished and no rebuild is due yet, one
    extra cached forward pass on the finished block refreshes the cache.

    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, L).
        steps: Sampling steps, less than or equal to gen_length. Only used when both
            threshold and factor are None, in which case it is split evenly across blocks
            and the masked tokens of a block are spread evenly over the per-block steps.
        gen_length: Generated answer length.
        init_block_length: Initial block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
        threshold: Confidence threshold forwarded to get_transfer_index.
        factor: Factor forwarded to get_transfer_index_dynamic; when given it takes
            precedence over threshold.
        confidence_threshold: Confidence threshold for block length prediction.
        cache_update_interval: Number of generated tokens after which the full cache is rebuilt.

    Returns:
        (x, nfe_history, block_history): the token tensor (prompt followed by the
        generated tokens), the number of model calls spent on each block, and the
        length of each block.
    '''
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    use_fixed_schedule = threshold is None and factor is None
    if use_fixed_schedule:
        assert confidence_threshold == float('inf'), "Currently does not support top-k sampling with adaptive block length"
        assert gen_length % init_block_length == 0, "gen_length must be divisible by init_block_length for fixed block length"
        num_blocks = gen_length // init_block_length
        assert steps % num_blocks == 0, "steps must be divisible by num_blocks for top-k sampling"
        steps = steps // num_blocks
        assert steps > 0, "steps must be at least num_blocks for top-k sampling"

    def build_schedule(block_mask):
        # (b, steps) number of tokens to commit at each step; None unless the fixed
        # top-k schedule is active. Previously None was always passed on, which made
        # get_transfer_index fail when neither threshold nor factor was given.
        if not use_fixed_schedule:
            return None
        masked = block_mask.sum(dim=1, keepdim=True)
        step_ids = torch.arange(steps, device=masked.device)
        return masked // steps + (step_ids < masked % steps).to(masked.dtype)

    def transfer(step_logits, step_tokens, step_mask, tokens, schedule, step):
        # factor wins over threshold when both are set (same branching as before).
        if factor is not None:
            return get_transfer_index_dynamic(step_logits, step_tokens, remasking, step_mask, tokens, None, factor)
        if schedule is None:
            quota = None
        elif step < schedule.shape[1]:
            quota = schedule[:, step]
        else:
            # Schedule exhausted but masks remain (should not normally happen):
            # commit everything that is left instead of indexing out of range.
            quota = step_mask.sum(dim=1)
        return get_transfer_index(step_logits, step_tokens, remasking, step_mask, tokens, quota, threshold)

    generated_length = 0
    nfe_history = []
    block_history = []  # Track block lengths like in generate_block.py
    since_rebuild = float('inf')  # will trigger cache recompute for first block

    while generated_length < gen_length:
        nfe = 0

        do_full_rebuild = since_rebuild >= cache_update_interval

        if do_full_rebuild:
            since_rebuild = 0
            output = model(x, use_cache=True)
            past_kv = output.past_key_values
            nfe += 1

            # only predict block length when doing full rebuild
            logits = output.logits
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
            block_length = compute_block_length(logits, predicted_tokens, prompt, gen_length, generated_length, init_block_length, confidence_threshold=confidence_threshold)
        else:
            # use cache from the previous block with the default block length
            block_length = init_block_length

        # Never run past the end of the generation budget.
        if generated_length + block_length > gen_length:
            block_length = gen_length - generated_length

        current_block_start = prompt.shape[1] + generated_length
        current_block_end = current_block_start + block_length
        generated_length += block_length
        since_rebuild += block_length
        block_history.append(block_length)

        # Marks the positions of the current block for the model's cached forward pass.
        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, current_block_start:current_block_end] = 1

        if do_full_rebuild:
            logits_init = logits[:, current_block_start:current_block_end]
        else:
            output = model(x[:, current_block_start:current_block_end], past_key_values=past_kv, use_cache=True, replace_position=replace_position)
            logits_init = output.logits
            nfe += 1
        # NOTE: after a full rebuild this draws fresh noise for the block instead of
        # reusing predicted_tokens; kept as is so sampling with temperature > 0 is unchanged.
        logits_init_with_noise = add_gumbel_noise(logits_init, temperature=temperature)
        pred_block = torch.argmax(logits_init_with_noise, dim=-1)

        # only allow transfer tokens in current block
        block_tokens = x[:, current_block_start:current_block_end]
        mask_index = (block_tokens == mask_id)

        schedule = build_schedule(mask_index)
        step = 0
        x0, transfer_index = transfer(logits_init, pred_block, mask_index, block_tokens, schedule, step)
        # block_tokens is a view of x, so this writes the committed tokens into x.
        block_tokens[transfer_index] = x0[transfer_index]

        # 2nd and later denoising steps in block
        while True:
            if (block_tokens == mask_id).sum() == 0:
                if since_rebuild < cache_update_interval:
                    # No full rebuild next time: refresh the cache with the finished block.
                    output = model(block_tokens, past_key_values=past_kv, use_cache=True, replace_position=replace_position)
                    past_kv = output.past_key_values
                    nfe += 1
                break
            step += 1
            # mask_index has size [b, block_length]
            mask_index = (block_tokens == mask_id)
            block_output = model(block_tokens, past_key_values=past_kv, use_cache=True, replace_position=replace_position)
            block_logits = block_output.logits
            block_logits_with_noise = add_gumbel_noise(block_logits, temperature=temperature)
            block_predicted_tokens = torch.argmax(block_logits_with_noise, dim=-1)
            nfe += 1
            x0, transfer_index = transfer(block_logits, block_predicted_tokens, mask_index, block_tokens, schedule, step)
            block_tokens[transfer_index] = x0[transfer_index]
        nfe_history.append(nfe)

    return x, nfe_history, block_history

def get_transfer_index(logits, predicted_tokens, remasking, mask_index, x, num_transfer_tokens, threshold=None):
    '''
    Decide which masked positions receive their predicted token in this step.

    Args:
        logits: Model logits of shape (b, l, vocab).
        predicted_tokens: Predicted token ids of shape (b, l).
        remasking: 'low_confidence' scores a position by the softmax probability of its
            predicted token; 'random' scores positions with uniform random numbers.
        mask_index: Boolean tensor of shape (b, l); True where a token may be transferred.
        x: Current tokens of shape (b, l); kept wherever mask_index is False.
        num_transfer_tokens: Per-row number of tokens to transfer, indexable by row.
            Ignored (may be None) when threshold is given.
        threshold: If given, every masked position whose score is at least threshold is
            transferred, and the best-scoring masked position of a row is always transferred.

    Returns:
        (x0, transfer_index): x0 holds the predictions at masked positions and x
        elsewhere; transfer_index is a boolean tensor marking the positions to commit.

    Raises:
        ValueError: If both threshold and num_transfer_tokens are None.
        NotImplementedError: If remasking is not a supported strategy.
    '''
    if threshold is None and num_transfer_tokens is None:
        raise ValueError("num_transfer_tokens is required when threshold is None")

    x0 = predicted_tokens  # b, l

    if remasking == 'low_confidence':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, x0, x)
    # Unmasked positions get -inf so that top-k never prefers them over masked ones.
    confidence = torch.where(mask_index, x0_p, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    if threshold is not None:
        # Every masked position is a candidate; the threshold does the filtering.
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
    for j in range(confidence.shape[0]):
        top_confidence, select_index = torch.topk(confidence[j], k=int(num_transfer_tokens[j]))
        if threshold is not None:
            # One vectorised comparison instead of a per-token Python loop.
            below = top_confidence < threshold
            below[:1] = False  # the best candidate is transferred unconditionally
            select_index = select_index[~below]
        transfer_index[j, select_index] = True
    return x0, transfer_index

def get_transfer_index_dynamic(logits, predicted_tokens, remasking, mask_index, x, num_transfer_tokens, factor=1):
    '''
    Decide which masked positions to commit, using a rank-dependent confidence threshold.

    Per row, the masked positions are ranked by score (highest first). The position of
    rank n (1-based) passes if its score is at least 1 - factor / (n + 1); rank 1 always
    passes. The positions before the first failing rank are transferred, so at least
    one token is transferred for every row that has a masked position.

    Args:
        logits: Model logits of shape (b, l, vocab).
        predicted_tokens: Predicted token ids of shape (b, l).
        remasking: 'low_confidence' scores a position by the softmax probability of its
            predicted token; 'random' scores positions with uniform random numbers.
        mask_index: Boolean tensor of shape (b, l); True where a token may be transferred.
        x: Current tokens of shape (b, l); kept wherever mask_index is False.
        num_transfer_tokens: Ignored; the count is taken from mask_index.
        factor: Scale of the rank-dependent threshold.

    Returns:
        (x0, transfer_index): x0 holds the predictions at masked positions and x
        elsewhere; transfer_index is a boolean tensor marking the positions to commit.

    Raises:
        NotImplementedError: If remasking is not a supported strategy.
    '''
    x0 = predicted_tokens  # b, l

    if remasking == 'low_confidence':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, x0, x)
    # Unmasked positions get -inf so that top-k never prefers them over masked ones.
    confidence = torch.where(mask_index, x0_p, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    num_masked = mask_index.sum(dim=1)

    for j in range(confidence.shape[0]):
        n = int(num_masked[j])
        if n == 0:
            # Nothing to transfer in this row (previously an IndexError).
            continue

        sorted_confidence = torch.sort(confidence[j][mask_index[j]], dim=-1, descending=True)[0]
        ranks = torch.arange(1, n + 1, device=sorted_confidence.device, dtype=torch.float64)
        thresholds = 1 - factor / (ranks + 1)
        # at least one token is transferred
        thresholds[0] = -1

        # Number of leading ranks that pass. The previous loop could not tell
        # "failed at the last rank" from "every rank passed" and transferred the
        # last token in both cases.
        failed = torch.nonzero(sorted_confidence < thresholds)
        num_accept = int(failed[0]) if failed.numel() > 0 else n

        _, select_index = torch.topk(confidence[j], k=num_accept)
        transfer_index[j, select_index] = True

    return x0, transfer_index